
#include "otp.h"

#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/time.h>
#include <openssl/hmac.h>

// ----------------------
/*
 * If `str` begins with `prefix`, return a pointer to the first character
 * after the prefix (pointing into `str`); otherwise return NULL.
 * An empty prefix always matches and yields `str` itself.
 */
static char *str_skip_prefix(const char *str, const char *prefix) {
	size_t prefix_len = strlen(prefix);
	if (strncmp(str, prefix, prefix_len) != 0)
		return NULL;
	return (char *)str + prefix_len;
}
/*
 * Value of a single hexadecimal digit ('0'-'9', 'a'-'f', 'A'-'F'),
 * or -1 if `ch` is not a hex digit.
 */
static int hex_nibble(char ch) {
	if (ch >= '0' && ch <= '9')
		return ch - '0';
	if (ch >= 'a' && ch <= 'f')
		return ch - 'a' + 10;
	if (ch >= 'A' && ch <= 'F')
		return ch - 'A' + 10;
	return -1;
}

/*
 * Decode the hexadecimal string `hex` into raw bytes in `binary`.
 *
 * *binary_length is set to strlen(hex) / 2 as soon as the length is known
 * to be even, i.e. before the digits are validated, so it is written even
 * when a later digit turns out to be invalid. `binary` must have room for
 * that many bytes; no bound is checked here.
 *
 * Returns OTP_SUCC on success, or OTP_ERR if the input has an odd number of
 * characters or contains a non-hex character (bytes decoded before the bad
 * character have already been written).
 */
static int hex_decode(const char *hex, char *binary, int *binary_length) {
	int digits = strlen(hex);
	if (digits % 2 != 0)
		return OTP_ERR;

	int bytes = digits / 2;
	*binary_length = bytes;

	const char *src = hex;
	for (int idx = 0; idx < bytes; idx++, src += 2) {
		int hi = hex_nibble(src[0]);
		int lo = hex_nibble(src[1]);
		if (hi < 0 || lo < 0)
			return OTP_ERR;
		binary[idx] = (char)(hi << 4 | lo);
	}
	return OTP_SUCC;
}

/*
 * 6-bit value of one base64 character. Both the standard ('+', '/') and the
 * URL-safe ('-', '_') alphabets are accepted. Returns -1 for any other
 * character, including '=' (padding is handled by the caller).
 */
static int b64_sextet(char ch) {
	if (ch >= 'A' && ch <= 'Z')
		return ch - 'A';
	if (ch >= 'a' && ch <= 'z')
		return ch - 'a' + 26;
	if (ch >= '0' && ch <= '9')
		return ch - '0' + 52;
	if (ch == '+' || ch == '-')
		return 62;
	if (ch == '/' || ch == '_')
		return 63;
	return -1;
}

/*
 * Decode the base64 string `b64` into raw bytes in `binary`.
 *
 * The input length must be a multiple of 4. Padding ('=') may appear only at
 * the very end, at most twice. An empty string decodes to zero bytes.
 * Unused low bits in the final group are not checked.
 *
 * On success *binary_length is set to the number of decoded bytes
 * (len / 4 * 3 minus the padding count) and OTP_SUCC is returned. `binary`
 * must have room for that many bytes; no bound is checked here.
 * On malformed input OTP_ERR is returned, and neither `binary` nor
 * *binary_length is touched.
 */
static int b64_decode(const char *b64, char *binary, int *binary_length) {
	size_t len = strlen(b64);
	if (len % 4 != 0)
		return OTP_ERR;

	/* Pass 1: validate everything before writing any output. Once padding
	 * starts only further '=' may follow, and at most two are allowed.
	 * (Validating up front also removes the need for the scratch VLA, which
	 * was a zero-length array -- undefined behaviour -- for empty input.) */
	int pad = 0;
	for (size_t i = 0; i < len; i++) {
		char ch = b64[i];
		if (ch == '=') {
			if (++pad > 2)
				return OTP_ERR;
		} else if (pad > 0 || b64_sextet(ch) < 0) {
			return OTP_ERR;
		}
	}

	int out_len = (int)(len / 4 * 3) - pad;

	/* Pass 2: every 4 characters carry 24 bits = 3 bytes. '=' contributes
	 * zero bits, and the bytes it stands for are simply not emitted. */
	unsigned char *out = (unsigned char *)binary;
	int m = 0;
	for (size_t n = 0; n < len; n += 4) {
		unsigned int v[4];
		for (int q = 0; q < 4; q++) {
			char ch = b64[n + q];
			v[q] = (ch == '=') ? 0u : (unsigned int)b64_sextet(ch);
		}
		unsigned char group[3] = {
			(unsigned char)((v[0] << 2 | v[1] >> 4) & 0xFFu),
			(unsigned char)((v[1] << 4 | v[2] >> 2) & 0xFFu),
			(unsigned char)((v[2] << 6 | v[3]) & 0xFFu),
		};
		for (int q = 0; q < 3 && m < out_len; q++)
			out[m++] = group[q];
	}

	*binary_length = out_len;
	return OTP_SUCC;
}

// ----------------------
static void gen_token(const char *secret, int secret_length, long counter, char *token) {
	int i;
	unsigned char data[8];
	for (i = 7; i >= 0; i--) {
		data[i] = counter & 0xFF;
		counter >>= 8;
	}
	
	unsigned char hash[20];
	int hash_len = 20;
	
	HMAC_CTX *p_ctx;
#if OPENSSL_VERSION_NUMBER < 0x10100000L
	HMAC_CTX ctx;
	HMAC_CTX_init(&ctx);
	p_ctx = &ctx;
#else
	p_ctx = HMAC_CTX_new();
#endif
	HMAC_Init_ex(p_ctx, secret, secret_length, EVP_sha1(), NULL);
	HMAC_Update(p_ctx, data, 8);
	HMAC_Final(p_ctx, hash, &hash_len);
#if OPENSSL_VERSION_NUMBER < 0x10100000L
	HMAC_CTX_cleanup(p_ctx);
#else
	HMAC_CTX_free(p_ctx);
#endif
	
	int offset = hash[19] & 0xf;
	unsigned int num = (hash[offset] & 0x7F) << 24 |
		(hash[offset + 1]) << 16 |
		(hash[offset + 2]) << 8  |
		(hash[offset + 3]);
	
	num = num % 1000000;
	sprintf(token, "%06d", num);
}

// ----------------------
/*
 * Convert a textual secret into raw key bytes.
 *
 * `format` selects how `secret` is interpreted:
 *   "hex"                - hexadecimal digits (decoded by hex_decode)
 *   "base64"             - standard or URL-safe base64 (decoded by b64_decode)
 *   "string", "" or NULL - the bytes of `secret` itself; the terminating NUL
 *                          is copied as well but not counted in the length
 *
 * The decoded length is stored in *binary_length. `binary` must be large
 * enough for the decoded data; no bound is checked here.
 *
 * Returns OTP_SUCC on success. Returns OTP_ERR in three cases:
 *   - `secret` is NULL;
 *   - `format` is unknown;
 *   - the selected decoder rejects the input. This is propagated instead of
 *     being reported as success with garbage key bytes.
 */
int otp_convert_secret(const char *secret, const char *format, char *binary, int *binary_length) {
	if (secret == NULL)
		return OTP_ERR;

	/* Handle the "no format" cases first: strcmp() on a NULL format is UB. */
	if (format == NULL || format[0] == '\0' || strcmp(format, "string") == 0) {
		size_t len = strlen(secret);
		*binary_length = (int)len;
		memcpy(binary, secret, len + 1);
		return OTP_SUCC;
	}
	if (strcmp(format, "hex") == 0)
		return hex_decode(secret, binary, binary_length);
	if (strcmp(format, "base64") == 0)
		return b64_decode(secret, binary, binary_length);
	return OTP_ERR;
}

// ----------------------
/*
 * Initialise an HOTP context from a NUL-terminated secret and a starting
 * counter. A NULL secret yields a zero-length secret. The pointer is stored,
 * not copied, so the string must stay valid while `ctx` is in use.
 */
void hotp_init(hotp_ctx *ctx, const char *secret, long counter) {
	ctx->counter = counter;
	ctx->secret = secret;
	if (secret != NULL)
		ctx->secret_length = strlen(secret);
	else
		ctx->secret_length = 0;
}
/*
 * Replace the context's secret with a byte buffer of explicit length.
 * Unlike hotp_init (which uses strlen), the buffer may contain NUL bytes.
 * The pointer is stored, not copied.
 */
void hotp_set_secret(hotp_ctx *ctx, const char *secret, int secret_length) {
	ctx->secret_length = secret_length;
	ctx->secret = secret;
}
/*
 * Generate the token for the context's current counter into `token`.
 * This function does not itself advance ctx->counter.
 */
void hotp_gen(hotp_ctx *ctx, char *token) {
	const long current = ctx->counter;
	hotp_gen2(ctx, token, current);
}
/*
 * Generate the token for an explicit `counter` value using the context's
 * secret. ctx->counter is neither read nor modified.
 */
void hotp_gen2(hotp_ctx *ctx, char *token, long counter) {
	const char *key = ctx->secret;
	gen_token(key, ctx->secret_length, counter, token);
}
/*
 * Check `token` against the context's current counter.
 * On OTP_SUCC the counter is advanced by one, so the next call checks the
 * following counter value. On failure the counter is left unchanged.
 */
int hotp_verify(hotp_ctx *ctx, const char *token) {
	int result = hotp_verify2(ctx, token, ctx->counter);
	if (result != OTP_SUCC)
		return result;
	ctx->counter += 1;
	return result;
}
/*
 * Check `token` against the token for an explicit `counter` value.
 * ctx->counter is neither read nor modified.
 *
 * Returns OTP_SUCC on an exact string match, OTP_ERR otherwise (including a
 * NULL `token`).
 */
int hotp_verify2(hotp_ctx *ctx, const char *token, long counter) {
	if (token == NULL)
		return OTP_ERR;

	/* gen_token writes a plain char string (6 digits + NUL); use char, not
	 * unsigned char, so the buffer matches gen_token/strcmp parameter types. */
	char expected[10];
	gen_token(ctx->secret, ctx->secret_length, counter, expected);

	/* Defensive: an empty generated token is never accepted as a match. */
	if (expected[0] != '\0' && strcmp(token, expected) == 0)
		return OTP_SUCC;
	return OTP_ERR;
}

// ----------------------
/*
 * Initialise a TOTP context.
 *
 * `secret` is a NUL-terminated string, or NULL for a zero-length secret. It
 * is stored by pointer, not copied.
 *
 * `interval` is the time-step length in seconds: totp_gen2/totp_verify2
 * divide a millisecond timestamp by 1000 and then by it.
 *
 * `window` is the number of time steps totp_verify2 will try.
 */
void totp_init(totp_ctx *ctx, const char *secret, int window, int interval) {
	ctx->interval = interval;
	ctx->window = window;
	ctx->secret_length = (secret != NULL) ? strlen(secret) : 0;
	ctx->secret = secret;
}
/*
 * Replace the context's secret with a byte buffer of explicit length.
 * Unlike totp_init (which uses strlen), the buffer may contain NUL bytes.
 * The pointer is stored, not copied.
 */
void totp_set_secret(totp_ctx *ctx, const char *secret, int secret_length) {
	ctx->secret_length = secret_length;
	ctx->secret = secret;
}
/*
 * Generate the TOTP token for the current wall-clock time into `token`.
 * Sub-second precision is discarded; the time is passed to totp_gen2 as
 * whole seconds times 1000.
 *
 * NOTE(review): the millisecond value is carried in a `long`. On platforms
 * where long (or time_t) is 32 bits, this multiplication overflows for
 * present-day times. Confirm the supported targets are LP64.
 */
void totp_gen(totp_ctx *ctx, char *token) {
	struct timeval now;
	gettimeofday(&now, NULL);
	long now_ms = now.tv_sec * 1000;
	totp_gen2(ctx, token, now_ms);
}
/*
 * Generate the TOTP token for `timestamp` (milliseconds).
 *
 * The HOTP counter is timestamp / 1000 / ctx->interval, so the interval is
 * in seconds.
 *
 * If ctx->interval is not positive, no counter can be derived (the division
 * would be by zero), and `token` is set to the empty string instead.
 */
void totp_gen2(totp_ctx *ctx, char *token, long timestamp) {
	if (ctx->interval <= 0) {
		token[0] = '\0';
		return;
	}
	long counter = timestamp / 1000 / ctx->interval;
	gen_token(ctx->secret, ctx->secret_length, counter, token);
}
/*
 * Verify `token` against the current wall-clock time (whole seconds, passed
 * to totp_verify2 as milliseconds). Returns whatever totp_verify2 returns.
 *
 * NOTE(review): the millisecond value is carried in a `long`. On platforms
 * where long (or time_t) is 32 bits, this multiplication overflows for
 * present-day times. Confirm the supported targets are LP64.
 */
int totp_verify(totp_ctx *ctx, const char *token) {
	struct timeval now;
	gettimeofday(&now, NULL);
	long now_ms = now.tv_sec * 1000;
	return totp_verify2(ctx, token, now_ms);
}
/*
 * Verify `token` against the time steps around `timestamp` (milliseconds).
 *
 * Let c = timestamp / 1000 / ctx->interval. Candidate counters are tried in
 * the order c, c-1, c+1, c-2, c+2, ... until a match is found or ctx->window
 * candidates have been tried. A window of 1 or less checks only c.
 *
 * Returns OTP_SUCC on a match. Returns OTP_ERR when:
 *   - no candidate matches;
 *   - `token` is NULL;
 *   - ctx->interval is not positive (it is used as a divisor).
 */
int totp_verify2(totp_ctx *ctx, const char *token, long timestamp) {
	if (token == NULL || ctx->interval <= 0)
		return OTP_ERR;

	long counter = timestamp / 1000 / ctx->interval;

	/* gen_token writes a plain char string (6 digits + NUL). */
	char expected[10];
	int checked = 0;
	for (int i = 0; ; i++) {
		gen_token(ctx->secret, ctx->secret_length, counter - i, expected);
		/* Defensive: an empty generated token is never accepted. */
		if (expected[0] != '\0' && strcmp(token, expected) == 0)
			return OTP_SUCC;
		if (++checked >= ctx->window)
			break;
		/* For i == 0 the "future" step would repeat c, so skip it. */
		if (i > 0) {
			gen_token(ctx->secret, ctx->secret_length, counter + i, expected);
			if (expected[0] != '\0' && strcmp(token, expected) == 0)
				return OTP_SUCC;
			if (++checked >= ctx->window)
				break;
		}
	}
	return OTP_ERR;
}

// ======================
/*
 * Parse a command-line secret into raw key bytes via otp_convert_secret.
 *
 * An optional prefix selects the encoding: "hex:", "b64:"/"base64:", or
 * "string:". Without a recognised prefix the whole argument is used as a
 * literal string.
 *
 * The argument must be 1..50 characters long (prefix included). Returns
 * OTP_ERR for NULL or out-of-range input, otherwise otp_convert_secret's
 * result.
 */
static int parse_secret(char *src, char *secret, int *secret_length) {
	if (src == NULL)
		return OTP_ERR;

	size_t src_len = strlen(src);
	if (src_len == 0 || src_len > 50)
		return OTP_ERR;

	/* Checked in order; the first matching prefix wins. */
	static const struct {
		const char *prefix;
		const char *format;
	} encodings[] = {
		{ "hex:",    "hex"    },
		{ "b64:",    "base64" },
		{ "base64:", "base64" },
		{ "string:", "string" },
	};

	for (size_t k = 0; k < sizeof encodings / sizeof encodings[0]; k++) {
		char *rest = str_skip_prefix(src, encodings[k].prefix);
		if (rest != NULL)
			return otp_convert_secret(rest, encodings[k].format, secret, secret_length);
	}
	return otp_convert_secret(src, "string", secret, secret_length);
}

/*
 * Parse a whole command-line argument as a base-10 integer.
 * Returns 1 and stores the value in *out only when `arg` is entirely numeric
 * (no trailing characters), fits in a long, and lies within [min, max].
 * Otherwise returns 0 and leaves *out untouched.
 */
static int parse_long_arg(const char *arg, long min, long max, long *out) {
	char *end;
	errno = 0;
	long value = strtol(arg, &end, 10);
	if (end == arg || *end != '\0' || errno == ERANGE)
		return 0;
	if (value < min || value > max)
		return 0;
	*out = value;
	return 1;
}

/*
 * Command-line driver for the HOTP/TOTP functions in this file.
 *
 * Usage: <func> <secret> <args...>, where <secret> accepts the prefixes
 * understood by parse_secret ("hex:", "b64:"/"base64:", "string:").
 *
 * Numeric arguments must be plain decimal integers, and a TOTP interval must
 * be positive because it is used as a divisor. Previously atoi/atol silently
 * turned garbage into 0, which crashed on a zero interval.
 *
 * Any malformed invocation prints "[error] param err" and exits with
 * status 1.
 */
int main(int argc, char *argv[]) {
	/* Token buffers are plain char strings (6 digits + NUL). */
	char ret[10];
	int retval;

	if (argc == 1 || (argc == 2 && strcmp(argv[1], "help") == 0)) {
		printf("Example:\n");
		printf("    hotp_gen    123456 0\n");
		printf("    hotp_verify 123456 0 888888\n");
		printf("    totp_gen    123456 3 30 [1614054214456]\n");
		printf("    totp_verify 123456 3 30 888888 [1614054214456]\n");
		return 0;
	}

	if (argc < 3) {
		printf("[error] param err\n");
		return 1;
	}

	char *func = argv[1];
	char *src_key = argv[2];
	/* parse_secret caps its input at 50 characters, so 64 bytes is enough
	 * for every supported encoding (including the copied NUL). */
	char key[64];
	int keylen;
	if (parse_secret(src_key, key, &keylen) != OTP_SUCC) {
		printf("[error] param err\n");
		return 1;
	}

	printf(">> %s\n", func);
	long counter = 0, window = 0, interval = 0, timestamp = 0;
	if (strcmp(func, "hotp_gen") == 0 && argc == 4
			&& parse_long_arg(argv[3], LONG_MIN, LONG_MAX, &counter)) {
		hotp_ctx ctx;
		hotp_init(&ctx, NULL, counter);
		hotp_set_secret(&ctx, key, keylen);
		hotp_gen(&ctx, ret);
		printf("%s\n", ret);
		return 0;
	} else if (strcmp(func, "totp_gen") == 0 && argc >= 5
			&& parse_long_arg(argv[3], INT_MIN, INT_MAX, &window)
			&& parse_long_arg(argv[4], 1, INT_MAX, &interval)
			&& (argc == 5 || parse_long_arg(argv[5], LONG_MIN, LONG_MAX, &timestamp))) {
		totp_ctx ctx;
		totp_init(&ctx, NULL, (int)window, (int)interval);
		totp_set_secret(&ctx, key, keylen);
		if (argc == 5)
			totp_gen(&ctx, ret);
		else
			totp_gen2(&ctx, ret, timestamp);
		printf("%s\n", ret);
		return 0;
	} else if (strcmp(func, "hotp_verify") == 0 && argc == 5
			&& parse_long_arg(argv[3], LONG_MIN, LONG_MAX, &counter)) {
		hotp_ctx ctx;
		hotp_init(&ctx, NULL, counter);
		hotp_set_secret(&ctx, key, keylen);
		retval = hotp_verify(&ctx, argv[4]);
		printf("%d\n", retval);
		return 0;
	} else if (strcmp(func, "totp_verify") == 0 && argc >= 6
			&& parse_long_arg(argv[3], INT_MIN, INT_MAX, &window)
			&& parse_long_arg(argv[4], 1, INT_MAX, &interval)
			&& (argc == 6 || parse_long_arg(argv[6], LONG_MIN, LONG_MAX, &timestamp))) {
		totp_ctx ctx;
		totp_init(&ctx, NULL, (int)window, (int)interval);
		totp_set_secret(&ctx, key, keylen);
		if (argc == 6)
			retval = totp_verify(&ctx, argv[5]);
		else
			retval = totp_verify2(&ctx, argv[5], timestamp);
		printf("%d\n", retval);
		return 0;
	}

	printf("[error] param err\n");
	return 1;
}

